import math
import os
import cv2
import numpy as np
import skimage.morphology
from PIL import Image
from torchvision import transforms

from envs.utils.fmm_planner import FMMPlanner
from envs.habitat.objectgoal_env import ObjectGoal_Env
from constants import color_palette
import envs.utils.pose as pu
import agents.utils.visualization as vu
from agents.utils.rednet import SemanticPredRedNet, draw_sem_img
from agents.utils.semantic_prediction import SemanticPredMaskRCNN

import imageio

# Frames buffered by ``Sem_Exp_Env_Agent._visualize_gif`` (when
# ``args.print_images`` is set) until they are written to a video file and the
# list is cleared. It is module-level, so it is shared by every agent instance
# living in this process.
visualize_img_list = []

class UnTrapHelper:
    """Produce alternating turn actions to help a stuck agent escape.

    Each call to ``reset`` starts a new escape attempt. Within an attempt the
    first action is 2 when the attempt counter is even and 3 when it is odd;
    every later action of the same attempt is the other of the two.
    """

    def __init__(self):
        # Number of escape attempts started so far (drives the turn parity).
        self.total_id = 0
        # Number of actions issued within the current attempt.
        self.epi_id = 0

    def reset(self):
        """Start a new escape attempt."""
        self.total_id += 1
        self.epi_id = 0

    def get_action(self):
        """Return the next turn action id (2 or 3) for the current attempt."""
        self.epi_id += 1
        even_attempt = self.total_id % 2 == 0
        first_step = self.epi_id == 1
        # First step: 2 on even attempts, 3 on odd ones; afterwards swapped.
        return 2 if even_attempt == first_step else 3

class Sem_Exp_Env_Agent(ObjectGoal_Env):
    """The Sem_Exp environment agent class. A seperate Sem_Exp_Env_Agent class
    object is used for each environment thread.

    """

    def __init__(self, args, rank, config_env, dataset):
        """Create the agent environment for one environment thread.

        Args:
            args: experiment arguments. NOTE: ``args.sem_gpu_id`` is
                overwritten here with the ``"cuda:<id>"`` device selected
                for this rank.
            rank (int): index of this environment thread; used to pick a
                semantic-model GPU from ``args.sem_gpu_id_list``.
            config_env: Habitat config; its simulator GPU id is used when the
                selected semantic-model GPU id is ``-1``.
            dataset: forwarded to ``ObjectGoal_Env``.

        Raises:
            ValueError: if ``args.backbone_2d`` is neither ``"rednet"`` nor
                ``"maskrcnn"``.
        """
        self.args = args
        super().__init__(args, rank, config_env, dataset)

        # initialize transform for RGB observations
        self.res = transforms.Compose(
            [transforms.ToPILImage(),
             transforms.Resize((args.frame_height, args.frame_width),
                               interpolation=Image.NEAREST)])

        # Assign semantic models to the GPUs listed in
        # ``args.sem_gpu_id_list`` round-robin by rank.
        sem_gpu_visible_devices = [
            int(str_gpu_id) for str_gpu_id in args.sem_gpu_id_list.split(",")]
        sem_gpu_id = sem_gpu_visible_devices[
            int(self.rank) % len(sem_gpu_visible_devices)]
        # -1 means "use the simulator's GPU". Check the integer id before it
        # is formatted as a device string, otherwise it can never match.
        if sem_gpu_id == -1:
            sem_gpu_id = config_env.SIMULATOR.HABITAT_SIM_V0.GPU_DEVICE_ID
        args.sem_gpu_id = "cuda:" + str(sem_gpu_id)

        # initialize semantic segmentation prediction model
        if args.backbone_2d == "rednet":
            self.sem_pred_model = SemanticPredRedNet(args)
        elif args.backbone_2d == "maskrcnn":
            self.sem_pred_model = SemanticPredMaskRCNN(args)
        else:
            # Fail early instead of an AttributeError on the first frame.
            raise ValueError(
                "Unsupported backbone_2d: {!r}".format(args.backbone_2d))

        # initializations for planning:
        self.selem = skimage.morphology.disk(3)

        self.obs = None
        self.obs_ = None
        self.obs_shape = None
        self.collision_map = None
        self.visited = None
        self.visited_vis = None
        self.col_width = None
        self.curr_loc = None
        self.last_loc = None
        self.last_action = None
        self.count_forward_actions = None

        # Trap-helper state (also re-initialised in reset()).
        self.untrap = UnTrapHelper()
        self.block_threshold = 3
        self.prev_blocked = 0
        self._previous_action = -1

        if args.visualize or args.print_images:
            self.legend = cv2.imread('docs/legend.png')
            self.vis_image = None
            self.rgb_vis = None

    def reset(self):
        """Reset the episode and all per-episode planning state.

        Returns:
            obs (ndarray): preprocessed observation ((4+C) x H x W)
            obs_ (ndarray): RGB + processed depth at env resolution (H x W x 4)
            info (dict): episode info with 'timestep', 'episode_no' and
                'rank' filled in
        """
        args = self.args

        obs, info = super().reset()
        obs, obs_ = self._preprocess_obs(obs, info['goal_cat_id'])

        # Store the first observation so code that reads self.obs / self.obs_
        # before the first real step (e.g. the zero-filled "wait" outputs of
        # plan_act_and_preprocess) does not hit None.
        self.obs = obs
        self.obs_ = obs_
        self.obs_shape = obs.shape

        # Episode initializations
        map_shape = (args.map_size_cm // args.map_resolution,
                     args.map_size_cm // args.map_resolution)
        self.collision_map = np.zeros(map_shape)
        self.visited = np.zeros(map_shape)
        self.visited_vis = np.zeros(map_shape)
        self.col_width = 1
        self.count_forward_actions = 0
        # Start pose: the centre of the full map, heading 0.
        self.curr_loc = [args.map_size_cm / 100.0 / 2.0,
                         args.map_size_cm / 100.0 / 2.0, 0.]
        self.last_action = None

        if args.visualize or args.print_images:
            self.vis_image = vu.init_vis_image(self.goal_name, self.legend)

        self.info['timestep'] = self.timestep
        self.info['episode_no'] = self.episode_no
        self.info['rank'] = self.rank

        # Trap-helper state.
        self.block_threshold = 3
        self.prev_blocked = 0
        self._previous_action = -1
        self.untrap = UnTrapHelper()

        return obs, obs_, info

    def verify_action(self):
        """Take one action-0 (stop) step in the environment, then reset it.

        The step results are discarded.
        """
        stop_request = {'action': 0}
        _obs, _rew, _done, _info = super().step(stop_request)
        super().reset()


    def plan_act_and_preprocess(self, planner_inputs):
        """Plan an action, take it in the environment and preprocess the
        resulting observation.

        Args:
            planner_inputs (dict):
                dict with following keys:
                    'map_pred'  (ndarray): (M, M) map prediction
                    'goal'      (ndarray): (M, M) mat denoting goal locations
                    'pose_pred' (ndarray): (7,) array denoting pose (x,y,o)
                                 and planning window (gx1, gx2, gy1, gy2)
                    'found_goal' (bool): whether the goal object is found
                    'wait'      (bool): if true, do not act this step
                    'new_goal'  (bool): if true, reset self.info['g_reward']
                plus 'exp_pred' and 'sem_map_pred' when visualisation is
                enabled (read by _visualize_map).

        Returns:
            obs (ndarray): preprocessed observations ((4+C) x H x W)
            obs_ (ndarray): RGB + processed depth at env resolution (H x W x 4)
            reward (float): amount of reward returned after previous action
            done (bool): whether the episode has ended
            info (dict): contains timestep, pose, goal category and
                         evaluation metric info
        """

        def _skip_step():
            """Return zero-filled observations without acting."""
            self.last_action = None
            self.info["sensor_pose"] = [0., 0., 0.]
            # Both shapes come from the stored observations; the previous
            # code passed the array itself (``np.zeros(self.obs_)``).
            return (np.zeros(self.obs_shape), np.zeros(self.obs_.shape),
                    0., False, self.info)

        # plan
        if planner_inputs["wait"]:
            return _skip_step()

        if planner_inputs["new_goal"]:
            self.info["g_reward"] = 0

        # FMM plan path to the long-term goal
        action, achieve_goal = self._plan(planner_inputs)

        if self.args.visualize or self.args.print_images:
            map_img = self._visualize_map(planner_inputs)
            self._visualize_gif(map_img, action)

        self.info['timestep'] = self.timestep
        self.info['episode_no'] = self.episode_no
        self.info['rank'] = self.rank
        self.info['achieve_goal'] = achieve_goal

        if action < 0:
            # _plan in this class only returns ids >= 0; kept as a no-op path.
            return _skip_step()

        # act
        obs, rew, done, info = super().step({'action': action})
        obs, obs_ = self._preprocess_obs(obs, info['goal_cat_id'])

        self.last_action = action
        self.obs = obs
        self.obs_ = obs_
        self.info = info

        self.info['g_reward'] += rew

        return obs, obs_, rew, done, self.info



    def _plan(self, planner_inputs):
        """Function responsible for planning

        Runs a deterministic local policy:
        1. Updates the visited and collision maps from the last move.
        2. Asks ``_get_stg`` (FMM planner) for a short-term goal.
        3. Turns towards that goal or moves forward.

        When the trap helper is enabled and the agent has been blocked
        ``self.block_threshold`` or more times in a row, the action is
        overridden.

        Args:
            planner_inputs (dict):
                dict with following keys:
                    'map_pred'  (ndarray): (M, M) map prediction
                    'goal'      (ndarray): (M, M) goal locations
                    'pose_pred' (ndarray): (7,) array  denoting pose (x,y,o)
                                 and planning window (gx1, gx2, gy1, gy2)
                    'found_goal' (bool): whether the goal object is found

        Returns:
            action (int): action id (0 stop, 1 forward, 2 left, 3 right)
            achieve_goal (bool): the ``stop`` flag returned by ``_get_stg``
        """
        args = self.args

        # The pose stored on the previous call becomes the "last" pose.
        self.last_loc = self.curr_loc

        # Get Map prediction (rounded to the nearest integer)
        map_pred = np.rint(planner_inputs['map_pred'])
        goal = planner_inputs['goal']

        # Get pose prediction and global policy planning window
        start_x, start_y, start_o, gx1, gx2, gy1, gy2 = \
            planner_inputs['pose_pred']
        gx1, gx2, gy1, gy2 = int(gx1), int(gx2), int(gy1), int(gy2)
        planning_window = [gx1, gx2, gy1, gy2]

        # Get curr loc
        self.curr_loc = [start_x, start_y, start_o]
        # Convert the pose to a [row, col] cell of the local window: scale by
        # 100 / map_resolution (presumably metres -> cm -> cells), then shift
        # by the window origin. Row comes from y, column from x.
        r, c = start_y, start_x
        start = [int(r * 100.0 / args.map_resolution - gx1),
                 int(c * 100.0 / args.map_resolution - gy1)]
        start = pu.threshold_poses(start, map_pred.shape)

        # Mark the agent's current cell as visited.
        self.visited[gx1:gx2, gy1:gy2][start[0] - 0:start[0] + 1,
                                       start[1] - 0:start[1] + 1] = 1

        if args.visualize or args.print_images:
            # Get last loc
            last_start_x, last_start_y = self.last_loc[0], self.last_loc[1]
            r, c = last_start_y, last_start_x
            last_start = [int(r * 100.0 / args.map_resolution - gx1),
                          int(c * 100.0 / args.map_resolution - gy1)]
            last_start = pu.threshold_poses(last_start, map_pred.shape)
            # Draw the trajectory segment from the last to the current cell.
            self.visited_vis[gx1:gx2, gy1:gy2] = \
                vu.draw_line(last_start, start,
                             self.visited_vis[gx1:gx2, gy1:gy2])

        # Collision check (only after a forward action, id 1)
        if self.last_action == 1:
            x1, y1, t1 = self.last_loc
            x2, y2, _ = self.curr_loc
            # buf: offset (in 0.05 pose-unit steps) of the first marked cell
            # ahead of the agent; length: number of cells marked along the
            # heading.
            buf = 4
            length = 2

            # Barely moved on both axes: widen the collision footprint
            # (1 -> 3 -> 5, capped at 5). Each further stall while already at
            # width 5 marks a longer, closer patch (length 4, buf 3).
            if abs(x1 - x2) < 0.05 and abs(y1 - y2) < 0.05:
                self.col_width += 2
                if self.col_width == 7:
                    length = 4
                    buf = 3
                self.col_width = min(self.col_width, 5)
            else:
                self.col_width = 1

            dist = pu.get_l2_distance(x1, x2, y1, y2)
            if dist < args.collision_threshold:  # Collision
                self.prev_blocked += 1
                width = self.col_width
                # Mark a length x width patch in front of the previous pose
                # (heading t1, in degrees) as blocked. Note collision_map is
                # indexed in full-map cells, not window-relative cells.
                for i in range(length):
                    for j in range(width):
                        wx = x1 + 0.05 * \
                            ((i + buf) * np.cos(np.deg2rad(t1))
                             + (j - width // 2) * np.sin(np.deg2rad(t1)))
                        wy = y1 + 0.05 * \
                            ((i + buf) * np.sin(np.deg2rad(t1))
                             - (j - width // 2) * np.cos(np.deg2rad(t1)))
                        r, c = wy, wx
                        r, c = int(r * 100 / args.map_resolution), \
                            int(c * 100 / args.map_resolution)
                        [r, c] = pu.threshold_poses([r, c],
                                                    self.collision_map.shape)
                        self.collision_map[r, c] = 1
            else:
                # Moved successfully: if the agent had been trapped, start a
                # new untrap attempt (flips the turn parity) and clear the
                # blocked counter.
                if self.prev_blocked >= self.block_threshold:
                    self.untrap.reset()
                self.prev_blocked = 0

        stg, stop = self._get_stg(map_pred, start, np.copy(goal),
                                  planning_window)
        achieve_goal = False
        if stop:
            achieve_goal = True

        # Deterministic Local Policy
        if stop and planner_inputs['found_goal'] == 1:
            action = 0  # Stop
        else:
            (stg_x, stg_y) = stg
            # Direction from the agent's cell to the short-term goal, degrees.
            angle_st_goal = math.degrees(math.atan2(stg_x - start[0],
                                                    stg_y - start[1]))
            # Agent heading wrapped into (-180, 180].
            angle_agent = (start_o) % 360.0
            if angle_agent > 180:
                angle_agent -= 360

            # Signed heading error wrapped into (-180, 180]; turn when it
            # exceeds half a turn step, otherwise move forward.
            relative_angle = (angle_agent - angle_st_goal) % 360.0
            if relative_angle > 180:
                relative_angle -= 360

            if relative_angle > self.args.turn_angle / 2.:
                action = 3  # Right
            elif relative_angle < -self.args.turn_angle / 2.:
                action = 2  # Left
            else:
                action = 1  # Forward

        # Trap helper: once blocked block_threshold times in a row, alternate
        # an untrap turn (after a forward step) with a forward step.
        # NOTE(review): this also overrides action 0 (stop) while trapped —
        # confirm that is intended.
        if self.args.deactivate_traphelper == False:
            if self.prev_blocked >= self.block_threshold:
                if self._previous_action == 1:
                    action = self.untrap.get_action()
                    # print("trap helper begins!")
                else:
                    action = 1
        self._previous_action = action

        return action, achieve_goal

    def _get_stg(self, grid, start, goal, planning_window):
        """Get short-term goal.

        Builds a traversibility map for the local window, runs the FMM
        planner towards the (dilated) goal cells and returns the next
        waypoint.

        Args:
            grid (ndarray): (M, M) rounded obstacle map of the local window.
            start (list[int]): agent [row, col] inside the window.
            goal (ndarray): (M, M) goal map; non-zero cells are goals.
            planning_window (list[int]): [gx1, gx2, gy1, gy2] bounds of the
                window inside the full map.

        Returns:
            tuple: ((stg_x, stg_y), stop). The first item is the waypoint in
            window coordinates; the second is the planner's stop flag.
        """
        gx1, gx2, gy1, gy2 = planning_window

        # The planner covers the whole local window.
        x1, y1 = 0, 0
        x2, y2 = grid.shape

        def add_boundary(mat, value=1):
            """Return ``mat`` padded by one cell of ``value`` on every side."""
            h, w = mat.shape
            new_mat = np.zeros((h + 2, w + 2)) + value
            new_mat[1:h + 1, 1:w + 1] = mat
            return new_mat

        # Free space: cells not within ``self.selem`` of an obstacle ...
        traversible = np.logical_not(skimage.morphology.binary_dilation(
            grid[x1:x2, y1:y2], self.selem))
        # ... minus recorded collisions, plus every cell already visited.
        traversible[self.collision_map[gx1:gx2, gy1:gy2]
                    [x1:x2, y1:y2] == 1] = 0
        traversible[self.visited[gx1:gx2, gy1:gy2][x1:x2, y1:y2] == 1] = 1

        # Always free the 3x3 neighbourhood of the agent. Clamp the lower
        # bounds at 0: on the window's first row/column a bound of -1 would
        # wrap around and produce an empty slice.
        row = int(start[0] - x1)
        col = int(start[1] - y1)
        traversible[max(row - 1, 0):row + 2, max(col - 1, 0):col + 2] = 1

        traversible = add_boundary(traversible)
        goal = add_boundary(goal, value=0)

        planner = FMMPlanner(traversible)
        # Goal cells are dilated with a radius-10 disk; 1.0 marks a goal.
        goal = skimage.morphology.binary_dilation(
            goal, skimage.morphology.disk(10)).astype(np.float64)
        planner.set_multi_goal(goal)

        # +1 / -1 account for the one-cell padding added by add_boundary.
        state = [start[0] - x1 + 1, start[1] - y1 + 1]
        stg_x, stg_y, _, stop = planner.get_short_term_goal(state)

        stg_x, stg_y = stg_x + x1 - 1, stg_y + y1 - 1

        return (stg_x, stg_y), stop

    def _preprocess_obs(self, obs, cat_goal_id, use_seg=True):
        """Split a raw observation and build the model input state.

        Args:
            obs (ndarray): raw observation, channels first; channels 0-2 are
                RGB and channel 3 is depth.
            cat_goal_id: goal category id forwarded to the semantic model.
            use_seg (bool): unused.

        Returns:
            state (ndarray): channels-first stack of RGB, processed depth,
                semantic prediction, semantic entropy and goal prediction,
                subsampled by ``env_frame_width // frame_width`` when that
                factor is not 1.
            obs_ (ndarray): (H, W, 4) RGB + processed depth at the
                environment resolution.
        """
        cfg = self.args
        frame = obs.transpose(1, 2, 0)
        rgb = frame[:, :, :3]
        raw_depth = frame[:, :, 3:4]

        sem_seg_pred, sem_seg_entropy, sem_goal_pred = self._get_sem_pred(
            rgb.astype(np.uint8), raw_depth, cat_goal_id)

        depth = self._preprocess_depth(raw_depth, cfg.min_depth, cfg.max_depth)
        obs_ = np.concatenate((rgb, depth[:, :, np.newaxis]), axis=2)

        # Downscaling factor between the environment and model frames.
        ds = cfg.env_frame_width // cfg.frame_width
        if ds != 1:
            rgb = np.asarray(self.res(rgb.astype(np.uint8)))
            offset = ds // 2
            depth = depth[offset::ds, offset::ds]
            sem_seg_pred = sem_seg_pred[offset::ds, offset::ds]
            sem_seg_entropy = sem_seg_entropy[offset::ds, offset::ds]
            sem_goal_pred = sem_goal_pred[offset::ds, offset::ds]

        channels = (rgb,
                    depth[:, :, np.newaxis],
                    sem_seg_pred,
                    sem_seg_entropy[:, :, np.newaxis],
                    sem_goal_pred[:, :, np.newaxis])
        # channels-first (C, H, W)
        state = np.concatenate(channels, axis=2).transpose(2, 0, 1)
        return state, obs_


    def _preprocess_depth(self, depth, min_d, max_d):
        """Rescale channel 0 of ``depth`` from [0, 1] to [min_d, max_d] * 100.

        Pixels whose input value is exactly 0 or above 0.99 are set to 0.
        The input is presumably normalised depth and the factor 100
        presumably converts metres to centimetres; confirm against the env.

        Returns:
            ndarray: (H, W) rescaled depth (a new array).
        """
        plane = depth[:, :, 0] * 1
        invalid = (plane > 0.99) | (plane == 0)
        plane = plane * (max_d - min_d) * 100 + min_d * 100
        plane[invalid] = 0
        return plane


    def _get_sem_pred(self, rgb, depth, cat_goal):
        """Run the semantic model on an RGB-D frame.

        Also stores a BGR copy of ``rgb`` in ``self.rgb_vis``, which the map
        visualisation pastes into its canvas.

        Returns:
            tuple: (semantic prediction, semantic entropy, goal probability)
            as produced by ``self.sem_pred_model.get_prediction``.
        """
        model_out = self.sem_pred_model.get_prediction(rgb, depth, cat_goal)
        seg, entropy, goal_prob = model_out
        self.rgb_vis = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
        return seg, entropy, goal_prob


    def _write_statistic(self, info):
        """Append the episode's evaluation metrics to a per-episode text file.

        The episode directory is always created. The file
        ``<rank>-<episode_no>.txt`` is only written when
        ``args.print_images`` is set.

        Args:
            info (dict): must contain 'distance_to_goal', 'softspl' and
                'success' (each formatted with ``:7f``).
        """
        args = self.args
        dump_dir = "{}/dump/{}/".format(args.dump_location,
                                        args.exp_name)
        ep_dir = '{}/episodes/thread_{}/eps_{}/'.format(
            dump_dir, self.rank, self.episode_no)
        # exist_ok avoids the race of a separate os.path.exists() check.
        os.makedirs(ep_dir, exist_ok=True)

        if args.print_images:
            fn = '{}{}-{}.txt'.format(ep_dir, self.rank, self.episode_no)
            # The context manager closes the file even if formatting fails.
            with open(fn, "a") as f:
                f.write('{0:10} : {1:7f}, {2:10} : {3:7f}, {4:10} : {5:7f} \n'.format(
                    "distance_to_goal", info['distance_to_goal'],
                    "softspl", info['softspl'],
                    "success", info['success']))



    def _visualize_map(self, inputs):
        """Draw the RGB frame and the local semantic map into self.vis_image.

        Args:
            inputs (dict): planner inputs; reads 'map_pred', 'exp_pred',
                'pose_pred', 'goal' and 'sem_map_pred'. 'sem_map_pred' is not
                modified.

        Returns:
            ndarray: ``self.vis_image`` (updated in place).
        """
        args = self.args

        map_pred = inputs['map_pred']
        exp_pred = inputs['exp_pred']
        start_x, start_y, start_o, gx1, gx2, gy1, gy2 = inputs['pose_pred']
        gx1, gx2, gy1, gy2 = int(gx1), int(gx2), int(gy1), int(gy2)

        goal = inputs['goal']
        # Shift category ids past palette entries 0-4, which are assigned
        # below. Build a new array (not ``+=``) so the caller's
        # 'sem_map_pred' is not changed as a side effect.
        sem_map = inputs['sem_map_pred'] + 5

        # Cells whose shifted id is num_sem_categories + 4 have no category.
        no_cat_mask = sem_map == args.num_sem_categories + 4
        map_mask = np.rint(map_pred) == 1
        exp_mask = np.rint(exp_pred) == 1
        vis_mask = self.visited_vis[gx1:gx2, gy1:gy2] == 1

        # Palette index per layer; later assignments take precedence.
        sem_map[no_cat_mask] = 0
        sem_map[np.logical_and(no_cat_mask, exp_mask)] = 2
        sem_map[np.logical_and(no_cat_mask, map_mask)] = 1
        sem_map[vis_mask] = 3

        # Goal cells, enlarged with a radius-4 disk.
        goal_mask = np.asarray(
            skimage.morphology.binary_dilation(
                goal, skimage.morphology.disk(4)),
            dtype=bool)
        sem_map[goal_mask] = 4

        color_pal = [int(x * 255.) for x in color_palette]
        sem_map_vis = Image.new("P", (sem_map.shape[1],
                                      sem_map.shape[0]))
        sem_map_vis.putpalette(color_pal)
        sem_map_vis.putdata(sem_map.flatten().astype(np.uint8))
        sem_map_vis = sem_map_vis.convert("RGB")
        sem_map_vis = np.flipud(sem_map_vis)

        # PIL produced RGB; the OpenCV canvas is BGR.
        sem_map_vis = sem_map_vis[:, :, [2, 1, 0]]
        sem_map_vis = cv2.resize(sem_map_vis, (480, 480),
                                 interpolation=cv2.INTER_NEAREST)
        # NOTE: these offsets assume a 480x640 RGB frame and a 480x480 map
        # panel inside the canvas created by vu.init_vis_image.
        self.vis_image[50:530, 15:655] = self.rgb_vis
        self.vis_image[50:530, 670:1150] = sem_map_vis

        # Agent pose in map-panel pixels; the map was flipped vertically
        # above, hence the ``shape - y`` term.
        pos = (
            (start_x * 100. / args.map_resolution - gy1)
            * 480 / map_pred.shape[0],
            (map_pred.shape[1] - start_y * 100. / args.map_resolution + gx1)
            * 480 / map_pred.shape[1],
            np.deg2rad(-start_o)
        )

        agent_arrow = vu.get_contour_points(pos, origin=(670, 50))
        color = (int(color_palette[11] * 255),
                 int(color_palette[10] * 255),
                 int(color_palette[9] * 255))
        cv2.drawContours(self.vis_image, [agent_arrow], 0, color, -1)

        return self.vis_image

    def _visualize_gif(self, map_img, action):
        """Show and/or buffer a visualisation frame for the episode video.

        With ``args.visualize`` the frame is shown in an OpenCV window. With
        ``args.print_images`` it is converted to RGB and appended to the
        module-level ``visualize_img_list``. The buffered frames are written
        to ``<rank>-<episode>-Vis-<timestep>.mp4``, and the buffer is
        cleared, when the timestep is 499 or ``action`` is 0.
        """
        cfg = self.args
        out_root = "{}/dump/{}/".format(cfg.dump_location, cfg.exp_name)
        episode_dir = '{}/episodes/thread_{}/eps_{}/'.format(
            out_root, self.rank, self.episode_no)
        if not os.path.exists(episode_dir):
            os.makedirs(episode_dir)

        if cfg.visualize:
            cv2.imshow("Thread {}".format(self.rank), map_img)
            cv2.waitKey(1)

        if not cfg.print_images:
            return

        visualize_img_list.append(cv2.cvtColor(map_img, cv2.COLOR_BGR2RGB))
        flush_now = self.timestep == 499 or action == 0
        if flush_now:
            video_path = '{}/episodes/thread_{}/eps_{}/{}-{}-Vis-{}.mp4'.format(
                out_root, self.rank, self.episode_no,
                self.rank, self.episode_no, self.timestep)
            imageio.mimsave(video_path, visualize_img_list)
            visualize_img_list.clear()

    def _visualize_img(self, map_img, action):
        """Show and/or save a single visualisation frame as a PNG.

        With ``args.visualize`` the frame is shown in an OpenCV window. With
        ``args.print_images`` it is saved as
        ``<rank>-<episode>-Vis-<timestep>.png`` in the episode directory.
        ``action`` is unused.
        """
        cfg = self.args
        out_root = "{}/dump/{}/".format(cfg.dump_location, cfg.exp_name)
        episode_dir = '{}/episodes/thread_{}/eps_{}/'.format(
            out_root, self.rank, self.episode_no)
        if not os.path.exists(episode_dir):
            os.makedirs(episode_dir)

        if cfg.visualize:
            cv2.imshow("Thread {}".format(self.rank), map_img)
            cv2.waitKey(1)

        if not cfg.print_images:
            return

        png_path = '{}/episodes/thread_{}/eps_{}/{}-{}-Vis-{}.png'.format(
            out_root, self.rank, self.episode_no,
            self.rank, self.episode_no, self.timestep)
        cv2.imwrite(png_path, map_img)